{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 附一 基于 LangChain 自定义 LLM"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "LangChain 为基于 LLM 开发自定义应用提供了高效的开发框架，便于开发者迅速地激发 LLM 的强大能力，搭建 LLM 应用。LangChain 也同样支持多种大模型，内置了 OpenAI、LLAMA 等大模型的调用接口。但是，LangChain 并没有内置所有大模型，它通过允许用户自定义 LLM 类型，来提供强大的可扩展性。\n",
    "\n",
    "在本部分，我们以智谱为例，讲述如何基于 LangChain 自定义 LLM，让我们基于 LangChain 搭建的应用能够支持国内平台。\n",
    "\n",
    "本部分涉及相对更多 LangChain、大模型调用的技术细节，有精力同学可以学习部署，如无精力可以直接使用后续代码来支持调用。\n",
    "\n",
    "要实现自定义 LLM，需要定义一个自定义类继承自 LangChain 的 LLM 基类，然后定义两个函数：  \n",
    "① _generate 方法，接收一系列消息及其他参数，并返回输出；  \n",
    "② _stream 方法， 接收一系列消息及其他参数，并以流式返回结果。  \n",
    "\n",
    "首先我们导入所需的第三方库："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any, Dict, Iterator, List, Optional\n",
    "from zhipuai import ZhipuAI\n",
    "from langchain_core.callbacks import (\n",
    "    CallbackManagerForLLMRun,\n",
    ")\n",
    "from langchain_core.language_models import BaseChatModel\n",
    "from langchain_core.messages import (\n",
    "    AIMessage,\n",
    "    AIMessageChunk,\n",
    "    BaseMessage,\n",
    "    SystemMessage,\n",
    "    ChatMessage,\n",
    "    HumanMessage\n",
    ")\n",
    "from langchain_core.messages.ai import UsageMetadata\n",
    "from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult\n",
    "import time"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "由于LangChain的消息类型是`HumanMessage`、`AIMessage`等格式，与一般模型接收的字典格式不一样，因此我们需要先定义一个将LangChain的格式转为字典的函数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _convert_message_to_dict(message: BaseMessage) -> dict:\n",
    "    \"\"\"把LangChain的消息格式转为智谱支持的格式\n",
    "    Args:\n",
    "        message: The LangChain message.\n",
    "    Returns:\n",
    "        The dictionary.\n",
    "    \"\"\"\n",
    "    message_dict: Dict[str, Any] = {\"content\": message.content}\n",
    "    if (name := message.name or message.additional_kwargs.get(\"name\")) is not None:\n",
    "        message_dict[\"name\"] = name\n",
    "\n",
    "    # populate role and additional message data\n",
    "    if isinstance(message, ChatMessage):\n",
    "        message_dict[\"role\"] = message.role\n",
    "    elif isinstance(message, HumanMessage):\n",
    "        message_dict[\"role\"] = \"user\"\n",
    "    elif isinstance(message, AIMessage):\n",
    "        message_dict[\"role\"] = \"assistant\"\n",
    "    elif isinstance(message, SystemMessage):\n",
    "        message_dict[\"role\"] = \"system\"\n",
    "    else:\n",
    "        raise TypeError(f\"Got unknown type {message}\")\n",
    "    return message_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接着我们定义一个继承自 LLM 类的自定义 LLM 类："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 继承自 LangChain 的 BaseChatModel 类\n",
    "class ZhipuaiLLM(BaseChatModel):\n",
    "    \"\"\"自定义Zhipuai聊天模型。\n",
    "    \"\"\"\n",
    "    model_name: str = None\n",
    "    temperature: Optional[float] = None\n",
    "    max_tokens: Optional[int] = None\n",
    "    timeout: Optional[int] = None\n",
    "    stop: Optional[List[str]] = None\n",
    "    max_retries: int = 3\n",
    "    api_key: str | None = None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "上述初始化涵盖了我们平时常用的参数，也可以根据实际需求与智谱的 API 加入更多的参数。\n",
    "\n",
    "接下来我们实现_generate方法："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "    def _generate(\n",
    "        self,\n",
    "        messages: List[BaseMessage],\n",
    "        stop: Optional[List[str]] = None,\n",
    "        run_manager: Optional[CallbackManagerForLLMRun] = None,\n",
    "        **kwargs: Any,\n",
    "    ) -> ChatResult:\n",
    "        \"\"\"通过调用智谱API从而响应输入。\n",
    "\n",
    "        Args:\n",
    "            messages: 由messages列表组成的prompt\n",
    "            stop: 在模型生成的回答中有该字符串列表中的元素则停止响应\n",
    "            run_manager: 一个为LLM提供回调的运行管理器\n",
    "        \"\"\"\n",
    "        # 列表推导式 将 messages 的元素逐个转为智谱的格式\n",
    "        messages = [_convert_message_to_dict(message) for message in messages]\n",
    "        # 定义推理的开始时间\n",
    "        start_time = time.time()\n",
    "        # 调用 ZhipuAI 对处理消息\n",
    "        response = ZhipuAI(api_key=self.api_key).chat.completions.create(\n",
    "            model=self.model_name,\n",
    "            temperature=self.temperature,\n",
    "            max_tokens=self.max_tokens,\n",
    "            timeout=self.timeout,\n",
    "            stop=stop,\n",
    "            messages=messages\n",
    "        )\n",
    "        # 计算运行时间 由现在时间 time.time() 减去 开始时间start_time得到\n",
    "        time_in_seconds = time.time() - start_time\n",
    "        # 将返回的消息封装并返回\n",
    "        message = AIMessage(\n",
    "            content=response.choices[0].message.content, # 响应的结果\n",
    "            additional_kwargs={}, # 额外信息\n",
    "            response_metadata={\n",
    "                \"time_in_seconds\": round(time_in_seconds, 3), # 响应源数据 这里是运行时间 也可以添加其他信息\n",
    "            },\n",
    "            # 本次推理消耗的token\n",
    "            usage_metadata={\n",
    "                \"input_tokens\": response.usage.prompt_tokens, # 输入token\n",
    "                \"output_tokens\": response.usage.completion_tokens, # 输出token\n",
    "                \"total_tokens\": response.usage.total_tokens, # 全部token\n",
    "            },\n",
    "        )\n",
    "        generation = ChatGeneration(message=message)\n",
    "        return ChatResult(generations=[generation])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来我们实现另一核心的方法_stream，之前注释的代码本次不再注释："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "    def _stream(\n",
    "        self,\n",
    "        messages: List[BaseMessage],\n",
    "        stop: Optional[List[str]] = None,\n",
    "        run_manager: Optional[CallbackManagerForLLMRun] = None,\n",
    "        **kwargs: Any,\n",
    "    ) -> Iterator[ChatGenerationChunk]:\n",
    "        \"\"\"通过调用智谱API返回流式输出。\n",
    "\n",
    "        Args:\n",
    "            messages: 由messages列表组成的prompt\n",
    "            stop: 在模型生成的回答中有该字符串列表中的元素则停止响应\n",
    "            run_manager: 一个为LLM提供回调的运行管理器\n",
    "        \"\"\"\n",
    "        messages = [_convert_message_to_dict(message) for message in messages]\n",
    "        response = ZhipuAI().chat.completions.create(\n",
    "            model=self.model_name,\n",
    "            stream=True, # 将stream 设置为 True 返回的是迭代器，可以通过for循环取值\n",
    "            temperature=self.temperature,\n",
    "            max_tokens=self.max_tokens,\n",
    "            timeout=self.timeout,\n",
    "            stop=stop,\n",
    "            messages=messages\n",
    "        )\n",
    "        start_time = time.time()\n",
    "        # 使用for循环存取结果\n",
    "        for res in response:\n",
    "            if res.usage: # 如果 res.usage 存在则存储token使用情况\n",
    "                usage_metadata = UsageMetadata(\n",
    "                    {\n",
    "                        \"input_tokens\": res.usage.prompt_tokens,\n",
    "                        \"output_tokens\": res.usage.completion_tokens,\n",
    "                        \"total_tokens\": res.usage.total_tokens,\n",
    "                    }\n",
    "                )\n",
    "            # 封装每次返回的chunk\n",
    "            chunk = ChatGenerationChunk(\n",
    "                message=AIMessageChunk(content=res.choices[0].delta.content)\n",
    "            )\n",
    "\n",
    "            if run_manager:\n",
    "                # This is optional in newer versions of LangChain\n",
    "                # The on_llm_new_token will be called automatically\n",
    "                run_manager.on_llm_new_token(res.choices[0].delta.content, chunk=chunk)\n",
    "            # 使用yield返回 结果是一个生成器 同样可以使用for循环调用\n",
    "            yield chunk\n",
    "        time_in_sec = time.time() - start_time\n",
    "        # Let's add some other information (e.g., response metadata)\n",
    "        # 最终返回运行时间\n",
    "        chunk = ChatGenerationChunk(\n",
    "            message=AIMessageChunk(content=\"\", response_metadata={\"time_in_sec\": round(time_in_sec, 3)}, usage_metadata=usage_metadata)\n",
    "        )\n",
    "        if run_manager:\n",
    "            # This is optional in newer versions of LangChain\n",
    "            # The on_llm_new_token will be called automatically\n",
    "            run_manager.on_llm_new_token(\"\", chunk=chunk)\n",
    "        yield chunk"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "然后我们还需要定义一下模型的描述方法："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "    @property\n",
    "    def _llm_type(self) -> str:\n",
    "        \"\"\"获取此聊天模型使用的语言模型类型。\"\"\"\n",
    "        return self.model_name\n",
    "\n",
    "    @property\n",
    "    def _identifying_params(self) -> Dict[str, Any]:\n",
    "        \"\"\"返回一个标识参数的字典。\n",
    "\n",
    "        该信息由LangChain回调系统使用，用于跟踪目的，使监视llm成为可能。\n",
    "        \"\"\"\n",
    "        return {\n",
    "            \"model_name\": self.model_name,\n",
    "        }"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "通过上述步骤，我们就可以基于 LangChain 定义智谱的调用方式了。我们将此代码封装在 zhipu_llm.py 文件中。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
